import torch
# Dump every tensor stored at the top level of a checkpoint into one flat
# binary file: the raw element bytes of each tensor, concatenated in the
# iteration order of the loaded dict, with no header or per-tensor metadata.
save_path = "/data/DeepSpeedExamples/Megatron-LM/checkpoints/gpt2_15btest/1000/zero_pp_rank_0_mp_rank_00optim_states.pt"

# Everything is converted on the CPU below, so map storages to the CPU while
# loading; this also lets the script run on a host without the GPUs the
# checkpoint was saved from.
model_dict = torch.load(save_path, map_location='cpu')

weight_count = 0  # total number of scalar elements written

# The context manager guarantees the output is flushed and closed even if a
# conversion fails part-way through.
with open('model_parameter.bin', 'wb') as fp:
    # `num` is a 1-based progress counter that also counts skipped entries.
    for num, (k, v) in enumerate(model_dict.items(), start=1):
        print(k, num)
        # Skip 'num_batches_tracked' entries; they are not written out.
        if 'num_batches_tracked' in k:
            continue
        # NOTE(review): assumes every remaining value is a tensor; a
        # non-tensor entry will raise AttributeError here -- confirm against
        # the checkpoint layout.
        arr = v.cpu().numpy()
        # One bulk write per tensor. tobytes() emits the elements in C order,
        # which is the same byte stream as writing the flattened array one
        # scalar at a time, without the per-element Python overhead.
        fp.write(arr.tobytes())
        weight_count += arr.size

print('model_weight has Convert Completely!', weight_count)
